
[JAX] Fix BF16 tolerance for CGEMM + RS + BF16 test #2860

Merged
phu0ngng merged 3 commits into NVIDIA:main from phu0ngng:cgemm_bf16_fix_tols on Apr 9, 2026

Conversation

Collaborator

@phu0ngng phu0ngng commented Apr 9, 2026

Description

atol=1e-5 was too strict for BF16 comparisons between the NONE (plain collective GEMM) path and the Collective GEMM + RS path. Both paths split K across TP ranks and produce identical BF16 partial GEMMs, but reduce them in different orders:

  • NONE (NCCL all-reduce): ((p0+p1)+(p2+p3)) — binary tree in FP32 → BF16
  • RS (reduce_bf16 kernel): ((p0+p1)+p2)+p3 — sequential in FP32 → BF16

Different reduction associativity causes rounding differences of up to 1 BF16 ULP at the partial-GEMM magnitude. The combined tolerance atol + rtol*|ref| must cover this across all output scales:

  • Large outputs (|ref| > atol/rtol = 12.5): rtol=1e-2 dominates and provides sufficient coverage.
  • Near-zero outputs: rtol provides no coverage, so atol=0.125 (2× the worst-case 1-ULP diff at O(8) scale) is needed. atol=1e-5 failed because it is far below 1 ULP at any realistic activation magnitude.
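The arithmetic behind these constants can be checked directly. The sketch below uses only the values quoted above (not the test code itself), assuming the O(8) activation scale stated in this PR:

```python
import math

rtol, atol = 1e-2, 0.125

# Crossover where the rtol term starts to dominate atol + rtol * |ref|.
crossover = atol / rtol
assert crossover == 12.5

# bfloat16 stores 7 explicit mantissa bits, so for |x| in [8, 16)
# (binary exponent 3) one ULP is 2**(3 - 7) = 0.0625.
bf16_ulp_at_8 = 2.0 ** (math.floor(math.log2(8)) - 7)
assert bf16_ulp_at_8 == 0.0625

# atol = 0.125 is exactly twice the worst-case 1-ULP difference at O(8)
# scale, while the old atol = 1e-5 was thousands of times below one ULP.
assert atol == 2 * bf16_ulp_at_8
assert bf16_ulp_at_8 / 1e-5 > 6000
```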

Reproducer

The mismatch is verified by a standalone test (https://gist.github.com/phu0ngng/9600caf76df6040ecc4b3f3c6ea20882) that mimics the two collective paths on a single GPU:

  1. Generate TP=4 BF16 partial GEMMs matching the per-rank GEMM size from test_gemm.py (M=8192, K_tp=1024, N=16384, seed=PRNGKey(0)).
  2. Reduce via NCCL binary-tree order (C_none) and TE sequential order (C_rs).
  3. Compare element-wise in BF16 ULPs.
$ CUDA_VISIBLE_DEVICES=1 python test_gemm_reduction.py
M=8192 K_tp=1024 N=16384 TP=4: 2 diffs, max=0.5000, max_ulps=1.00 PASS

2 elements differ by exactly 1 BF16 ULP.
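The reproducer itself runs on GPU over full GEMM tiles. As a CPU-only illustration of the same effect, the pure-Python sketch below (hypothetical scalar partials with random values, not the gist's data) emulates BF16 partials, FP32 accumulation in the two orders, and the final BF16 cast:

```python
import random
import struct

def f32(x):
    """Round a Python float to the nearest float32 (simulates FP32 accumulation)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def bf16(x):
    """Round a float to bfloat16 via round-to-nearest-even, returned as a float."""
    (b,) = struct.unpack("<I", struct.pack("<f", x))
    b = ((b + 0x7FFF + ((b >> 16) & 1)) >> 16) & 0xFFFF
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

rng = random.Random(0)
rtol, atol = 1e-2, 0.125
diffs, max_abs, within_tol = 0, 0.0, True
for _ in range(200_000):
    # Four mixed-sign bf16 "partials", one per simulated TP rank; cancellation
    # is what makes the FP32 accumulation order-sensitive.
    p = [bf16(rng.gauss(0.0, 4.0)) for _ in range(4)]
    tree = bf16(f32(f32(p[0] + p[1]) + f32(p[2] + p[3])))  # NCCL-style binary tree
    seq = bf16(f32(f32(f32(p[0] + p[1]) + p[2]) + p[3]))   # TE-style sequential
    d = abs(tree - seq)
    if d:
        diffs += 1
        max_abs = max(max_abs, d)
        # Every mismatch must be covered by the PR's combined tolerance.
        within_tol = within_tol and d <= atol + rtol * abs(tree)
print(f"{diffs} mismatches / 200000, max |diff| = {max_abs}, within_tol = {within_tol}")
```

Any mismatches this finds come purely from reduction order, and each one falls inside the combined bound atol + rtol*|ref| adopted by this PR.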

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Increase atol from 1e-5 to 0.125 to cover the near-zero regime where rtol provides no coverage. Large-magnitude diffs (the common case) are already handled by rtol=1e-2.
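As a sketch of the override logic described above (hypothetical, using the names from the review flowchart; the actual code in examples/jax/collective_gemm/test_gemm.py may differ in detail):

```python
def pick_tolerances(collective_op, use_quantization):
    """Return (rtol, atol) overrides, or (None, None) for dtype defaults.

    Hypothetical sketch: only the RS + BF16 (unquantized) path reduces in a
    different order than the NONE reference, so only it gets the wider atol.
    """
    is_cgemm_rs_bf16 = collective_op == "REDUCE_SCATTER" and not use_quantization
    if is_cgemm_rs_bf16:
        return 1e-2, 0.125  # covers 1 BF16 ULP in the near-zero regime
    return None, None       # fall back to dtype_tols(bfloat16) defaults
```

With this shape, assert_allclose uses the explicit pair only when both values are provided, matching the both-provided vs. both-None behavior noted in the review below.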

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng marked this pull request as ready for review April 9, 2026 03:13
@phu0ngng phu0ngng requested a review from ptrendx April 9, 2026 03:13
@phu0ngng
Collaborator Author

phu0ngng commented Apr 9, 2026

/te-ci JAX L0

@greptile-apps
Contributor

greptile-apps bot commented Apr 9, 2026

Greptile Summary

This PR increases atol from 1e-5 to 0.125 for the CGEMM + Reduce-Scatter + BF16 test path to account for the different floating-point reduction orders used by NCCL (binary-tree) and TE's reduce_bf16 kernel (sequential left-to-right). The math is sound: near-zero outputs where rtol provides no coverage can differ by up to 1 BF16 ULP (~0.0625 at O(8) activation scale), making the previous 1e-5 bound far too tight.

Confidence Score: 5/5

Safe to merge — the change is a well-justified tolerance relaxation for a single test path, backed by a mathematical derivation and a standalone reproducer.

Only one file changes, the logic is correct (condition precisely identifies RS+BF16 without quantization), the new atol=0.125 is derived from first principles (2× the worst-case 1-ULP BF16 difference at O(8) scale), and assert_allclose correctly handles both-provided vs. both-None paths. No functional code is modified.

No files require special attention.

Vulnerabilities

No security concerns identified.

Important Files Changed

Filename Overview
examples/jax/collective_gemm/test_gemm.py Adds a targeted tolerance override (rtol=1e-2, atol=0.125) only for the CGEMM+RS+BF16 path, with a detailed comment explaining the reduction-order mismatch; all other paths continue to use dtype-default tolerances.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[run_gemm_tests] --> B{enable_result_check\nand process_id == 0?}
    B -- No --> Z[Skip check]
    B -- Yes --> C{collective_op == REDUCE_SCATTER\nand not use_quantization?}
    C -- Yes\nis_cgemm_rs_bf16=True --> D["rtol = 1e-2\natol = 0.125\n(covers 1 BF16 ULP near-zero)"]
    C -- No\nis_cgemm_rs_bf16=False --> E["rtol = None\natol = None\n(use dtype defaults)"]
    D --> F[assert_allclose\ngathered_ref_output vs gathered_output]
    E --> F
    F --> G{Both rtol and atol\nnot None?}
    G -- Yes --> H["Use provided\nrtol=1e-2, atol=0.125"]
    G -- No --> I["Fall back to\ndtype_tols(bfloat16)\nrtol=1e-2, atol=1e-5"]
    H --> J[np.testing.assert_allclose]
    I --> J

Reviews (2). Last reviewed commit: "Merge branch 'main' into cgemm_bf16_fix_..."

@phu0ngng phu0ngng merged commit ac73538 into NVIDIA:main Apr 9, 2026
9 of 12 checks passed
@phu0ngng phu0ngng deleted the cgemm_bf16_fix_tols branch April 9, 2026 22:03